import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import warnings
warnings.filterwarnings("ignore")
from sklearn.preprocessing import MinMaxScaler
from sklearn import metrics
import random
%%capture
!pip install -q hvplot
!pip install pytorch-tabnet
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import scikitplot as skplt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier,StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from scipy import stats
from numpy import isnan
from sklearn.impute import KNNImputer
from sklearn.model_selection import GridSearchCV, cross_val_score, StratifiedKFold, learning_curve
import pytorch_tabnet
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import KFold
from matplotlib.pyplot import figure
# Load the processed Cleveland heart-disease table from the UCI archive.
# The file has no header row and marks missing values with "?", so those are
# parsed straight to NaN at read time.
data = pd.read_csv(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data",
    header=None,
    na_values="?",
)
# Drop incomplete rows and renumber the index from 0.
data = data.dropna().reset_index(drop=True)
data.columns = ['age', 'sex', 'chest pain type', 'resting bp s', 'cholesterol',
                'fasting blood sugar', 'resting ecg', 'max heart rate',
                'exercise angina', 'oldpeak', 'ST slope', 'ca', 'thal', 'target']
# Every column except 'oldpeak' holds whole numbers; derive that list from the
# header instead of re-typing it, so the two cannot drift apart.
k = [col for col in data.columns if col != 'oldpeak']
data[k] = data[k].astype('float').astype('int')
data['oldpeak'] = data['oldpeak'].astype('float')
# Binarise the label: any positive diagnosis value becomes 1, zero stays 0.
data['target'] = np.where(data.target > 0, 1, 0)
# Integer-coded copy, used later to build the model inputs.
dataTab = data.copy()
data.head()
| age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | target | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | 6 | 0 |
| 1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | 3 | 1 |
| 2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | 7 | 1 |
| 3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | 3 | 0 |
| 4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | 3 | 0 |
# Translate the integer codes of three categorical columns into readable labels
# (for the EDA plots; the model uses the integer-coded copy instead).
CP_Dict = {1: 'typical angina', 2: 'atypical angina', 3: 'non-anginal', 4: 'asymptomatic'}
ECG_Dict = {0: 'normal', 1: 'ST-T wave abnormality', 2: 'left ventricular hypertrophy'}
thal_Dict = {3: 'normal', 6: 'fixed defect', 7: 'reversable defect'}
# One nested-dict replace: each inner mapping only touches its own column.
data.replace(
    {
        "chest pain type": CP_Dict,
        "resting ecg": ECG_Dict,
        "thal": thal_Dict,
    },
    inplace=True,
)
data.head()
| age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | target | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 63 | 1 | typical angina | 145 | 233 | 1 | left ventricular hypertrophy | 150 | 0 | 2.3 | 3 | 0 | fixed defect | 0 |
| 1 | 67 | 1 | asymptomatic | 160 | 286 | 0 | left ventricular hypertrophy | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
| 2 | 67 | 1 | asymptomatic | 120 | 229 | 0 | left ventricular hypertrophy | 129 | 1 | 2.6 | 2 | 2 | reversable defect | 1 |
| 3 | 37 | 1 | non-anginal | 130 | 250 | 0 | normal | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
| 4 | 41 | 0 | atypical angina | 130 | 204 | 0 | left ventricular hypertrophy | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
# Readable labels for the remaining coded columns.
Sex_Dict = {1: 'male', 0: 'female'}
FS_Dict = {0: 'under 120mgdl', 1: 'over 120mgdl'}
exang_Dict = {0: 'not induced', 1: 'induced'}
slope_Dict = {1: 'upsloping', 2: 'flat', 3: 'downsloping'}
# Column -> code-to-label mapping, applied in a single replace call.
data.replace(
    {
        "sex": Sex_Dict,
        "fasting blood sugar": FS_Dict,
        "exercise angina": exang_Dict,
        "ST slope": slope_Dict,
    },
    inplace=True,
)
# Snapshot of the fully labelled frame.
dataset = data.copy()
data.head()
| age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | target | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 63 | male | typical angina | 145 | 233 | over 120mgdl | left ventricular hypertrophy | 150 | not induced | 2.3 | downsloping | 0 | fixed defect | 0 |
| 1 | 67 | male | asymptomatic | 160 | 286 | under 120mgdl | left ventricular hypertrophy | 108 | induced | 1.5 | flat | 3 | normal | 1 |
| 2 | 67 | male | asymptomatic | 120 | 229 | under 120mgdl | left ventricular hypertrophy | 129 | induced | 2.6 | flat | 2 | reversable defect | 1 |
| 3 | 37 | male | non-anginal | 130 | 250 | under 120mgdl | normal | 187 | not induced | 3.5 | downsloping | 0 | normal | 0 |
| 4 | 41 | female | atypical angina | 130 | 204 | under 120mgdl | left ventricular hypertrophy | 172 | not induced | 1.4 | upsloping | 0 | normal | 0 |
data.target.value_counts(dropna=False)  # class balance, NaN counted if present
0 160 1 137 Name: target, dtype: int64
# Bar chart of how many patients fall in each target class.
f, axes = plt.subplots(figsize=(4, 6))
sns.countplot(data=data, x='target', palette=['green', 'orange'], ax=axes)
axes.set_title("Target Distribution", fontsize=20)
Text(0.5, 1.0, 'Target Distribution')
# Counts of each binary feature, split by target class, one panel per feature.
binary_features = ['sex', 'fasting blood sugar', 'exercise angina']
f, axes = plt.subplots(1, len(binary_features), figsize=(15, 5))
for ax, feature in zip(axes, binary_features):
    sns.countplot(ax=ax, x=feature, data=data, palette=['green', 'orange'], hue="target")
    # Title each panel through its own Axes; plt.title would target whichever
    # Axes happens to be "current".
    ax.set_title(feature, fontsize=20)
Text(0.5, 1.0, 'exercise angina')
# Chest pain type counts, split by target class.
plt.figure(figsize=(12, 5))
sns.countplot(
    data=data, x='chest pain type', hue="target", palette=['green', 'orange']
).set_title("chest pain type", fontsize=20)
Text(0.5, 1.0, 'chest pain type')
# ST-segment slope counts, split by target class.
plt.figure(figsize=(12, 5))
sns.countplot(
    data=data, x='ST slope', hue="target", palette=['green', 'orange']
).set_title("ST slope", fontsize=20)
Text(0.5, 1.0, 'ST slope')
# Resting ECG result counts, split by target class.
plt.figure(figsize=(12, 5))
ax = sns.countplot(data=data, x='resting ecg', hue="target", palette=['green', 'orange'])
ax.set_title("resting ecg", fontsize=20)
Text(0.5, 1.0, 'resting ecg')
# 'ca' value counts, split by target class.
plt.figure(figsize=(12, 5))
sns.countplot(
    data=data, x='ca', hue="target", palette=['green', 'orange']
).set_title("ca", fontsize=20)
Text(0.5, 1.0, 'ca')
# 'thal' category counts, split by target class.
plt.figure(figsize=(12, 5))
sns.countplot(
    data=data, x='thal', hue="target", palette=['green', 'orange']
).set_title("thal", fontsize=20)
Text(0.5, 1.0, 'thal')
# Split patients by outcome; these two frames are reused by the plots below.
data_disease = data[data["target"] == 1]
data_normal = data[data["target"] == 0]
plt.figure(figsize=(8, 5))
# sns.distplot is deprecated; histplot on a density scale with a KDE overlay is
# its replacement (alpha=0.4 matches distplot's default bar transparency).
# Labels + legend make the green/red series identifiable.
sns.histplot(data_normal["age"], bins=24, color='g', stat="density", kde=True,
             alpha=0.4, label="No Disease")
sns.histplot(data_disease["age"], bins=24, color='r', stat="density", kde=True,
             alpha=0.4, label="Disease")
plt.title("Distribution and density by Age", fontsize=20)
plt.xlabel("Age", fontsize=15)
plt.legend()
plt.show()
# Cholesterol distribution for healthy (green) vs. diseased (red) patients.
# histplot(stat="density", kde=True) replaces the deprecated sns.distplot.
plt.figure(figsize=(8, 5))
sns.histplot(data_normal["cholesterol"], bins=24, color='g', stat="density", kde=True,
             alpha=0.4, label="No Disease")
sns.histplot(data_disease["cholesterol"], bins=24, color='r', stat="density", kde=True,
             alpha=0.4, label="Disease")
plt.title("Distribution and density by cholesterol", fontsize=20)
plt.xlabel("cholesterol", fontsize=15)
plt.legend()
plt.show()
# Resting blood pressure distribution for healthy (green) vs. diseased (red).
# histplot(stat="density", kde=True) replaces the deprecated sns.distplot.
plt.figure(figsize=(8, 5))
sns.histplot(data_normal["resting bp s"], bins=24, color='g', stat="density", kde=True,
             alpha=0.4, label="No Disease")
sns.histplot(data_disease["resting bp s"], bins=24, color='r', stat="density", kde=True,
             alpha=0.4, label="Disease")
plt.title("Distribution and density by resting bp", fontsize=20)
plt.xlabel("resting bp", fontsize=15)
plt.legend()
plt.show()
# Max heart rate distribution for healthy (green) vs. diseased (red).
# histplot(stat="density", kde=True) replaces the deprecated sns.distplot.
plt.figure(figsize=(8, 5))
sns.histplot(data_normal["max heart rate"], bins=24, color='g', stat="density", kde=True,
             alpha=0.4, label="No Disease")
sns.histplot(data_disease["max heart rate"], bins=24, color='r', stat="density", kde=True,
             alpha=0.4, label="Disease")
plt.title("Distribution and density by max heart rate", fontsize=20)
plt.xlabel("max heart rate", fontsize=15)
plt.legend()
plt.show()
# 'oldpeak' distribution for healthy (green) vs. diseased (red).
# histplot(stat="density", kde=True) replaces the deprecated sns.distplot.
plt.figure(figsize=(8, 5))
sns.histplot(data_normal["oldpeak"], bins=24, color='g', stat="density", kde=True,
             alpha=0.4, label="No Disease")
sns.histplot(data_disease["oldpeak"], bins=24, color='r', stat="density", kde=True,
             alpha=0.4, label="Disease")
plt.title("Distribution and density by old peak", fontsize=20)
plt.xlabel("oldpeak", fontsize=15)
plt.legend()
plt.show()
# Age against maximum heart rate, one colour per outcome. Drawing order
# (disease first) matches the order of the legend labels below.
plt.figure(figsize=(9, 7))
for subset, colour in ((data_disease, "salmon"), (data_normal, "lightblue")):
    plt.scatter(subset["age"], subset["max heart rate"], c=colour)
plt.title("Heart Disease in function of Age and Max Heart Rate")
plt.xlabel("Age")
plt.ylabel("Max Heart Rate")
plt.legend(["Disease", "No Disease"]);
import hvplot.pandas  # registers the .hvplot accessor used below
# Several columns now hold string labels. Correlate only the numeric ones
# explicitly: pandas >= 2.0 raises on non-numeric columns in corrwith, while
# older versions dropped them silently, so this keeps the original output.
numeric_data = data.select_dtypes(include='number')
numeric_data.drop('target', axis=1).corrwith(numeric_data['target']).hvplot.barh(
    width=600, height=400,
    title="Correlation between Heart Disease and Numeric Features",
    ylabel='Correlation', xlabel='Numerical Features'
)
# Applies seaborn's theme globally (affecting every later figure) and enlarges
# the default figure size for the heatmap.
sns.set(rc={'figure.figsize': (12, 12)})
# Restrict to numeric columns: DataFrame.corr rejects string columns in
# pandas >= 2.0 (older versions dropped them silently).
sns.heatmap(data.select_dtypes(include='number').corr(), annot=True, fmt='.2g', cmap='coolwarm')
<AxesSubplot:>
from sklearn.preprocessing import LabelEncoder

# Build the model design matrix from the integer-coded copy of the data.
X = dataTab.drop(columns='target')
feats = X.columns
categorical_columns = [
    'sex', 'resting ecg', 'chest pain type', 'fasting blood sugar',
    'exercise angina', 'ST slope', 'thal',
]
# Re-encode each categorical column to 0..n-1 and remember its cardinality.
categorical_dims = {}
for col in categorical_columns:
    print(col, X[col].nunique())
    l_enc = LabelEncoder()
    X[col] = l_enc.fit_transform(X[col].values)
    categorical_dims[col] = len(l_enc.classes_)
# Positions and cardinalities of the categorical columns, both in feature
# order, as passed to TabNetClassifier(cat_idxs=..., cat_dims=...).
cat_idxs = [pos for pos, name in enumerate(feats) if name in categorical_columns]
cat_dims = [categorical_dims[feats[pos]] for pos in cat_idxs]
X = X.to_numpy()
target = dataTab['target'].to_numpy()
# 70 % train; the remaining 30 % is split 60/40 into validation and test.
x_train, x_val, y_train, y_val = train_test_split(
    X, target, test_size=0.3, random_state=42)
x_val, x_test, y_val, y_test = train_test_split(
    x_val, y_val, test_size=0.4, random_state=42)
sex 2 resting ecg 3 chest pain type 4 fasting blood sugar 2 exercise angina 2 ST slope 3 thal 3
# Baseline classifiers, each with a fixed random_state.
LR = LogisticRegression(solver='liblinear', random_state=42)
RF = RandomForestClassifier(criterion='entropy', random_state=42)
XGB = XGBClassifier(max_depth=5, random_state=42)
GBM = GradientBoostingClassifier(learning_rate=0.01, random_state=42)
models = [LR, RF, XGB, GBM]
metric_list = []
for m in models:
    mName = type(m).__name__
    m.fit(x_train, y_train)
    y_pred = m.predict(x_test)
    # ROC AUC needs a continuous score (positive-class probability). Hard 0/1
    # labels collapse the ROC curve to one operating point and understate or
    # distort the AUC.
    y_score = m.predict_proba(x_test)[:, 1]
    auroc = np.round(roc_auc_score(y_test, y_score), 4)
    accuracy = np.round(accuracy_score(y_test, y_pred), 4)
    precision = np.round(precision_score(y_test, y_pred), 4)
    recall = np.round(recall_score(y_test, y_pred), 4)
    f1 = np.round(f1_score(y_test, y_pred), 4)
    # Hard predictions are published as y_pred_<ModelName> globals; the
    # confusion-matrix plots read them back from there.
    globals()[f"y_pred_{mName}"] = y_pred
    metric_list.append([mName, auroc, accuracy, precision, recall, f1])
    print(mName, ":", auroc, accuracy, precision, recall, f1)
LogisticRegression : 0.9 0.9167 1.0 0.8 0.8889 RandomForestClassifier : 0.8762 0.8889 0.9231 0.8 0.8571 XGBClassifier : 0.8429 0.8611 0.9167 0.7333 0.8148 GradientBoostingClassifier : 0.7857 0.8056 0.8333 0.6667 0.7407
# Tabulate the baseline scores: best accuracy first, AUROC breaking ties.
df_metric_list = (
    pd.DataFrame(metric_list)
    .set_axis(['modelName', 'auroc', 'accuracy', 'precision', 'recall', 'f1_score'], axis=1)
    .sort_values(["accuracy", "auroc"], ascending=False)
    .reset_index(drop=True)
)
df_metric_list
| modelName | auroc | accuracy | precision | recall | f1_score | |
|---|---|---|---|---|---|---|
| 0 | LogisticRegression | 0.9000 | 0.9167 | 1.0000 | 0.8000 | 0.8889 |
| 1 | RandomForestClassifier | 0.8762 | 0.8889 | 0.9231 | 0.8000 | 0.8571 |
| 2 | XGBClassifier | 0.8429 | 0.8611 | 0.9167 | 0.7333 | 0.8148 |
| 3 | GradientBoostingClassifier | 0.7857 | 0.8056 | 0.8333 | 0.6667 | 0.7407 |
# Bold, large default text for the figures. 'normal' is not a font family
# matplotlib can resolve (it logs a findfont warning and falls back to the
# default), so only weight and size are set.
font = {'weight': 'bold', 'size': 18}
plt.rc('font', **font)
f, axes = plt.subplots(2, 2, figsize=(8, 9))
axes = axes.ravel()
f.suptitle("Confusion Matrix Base Models", fontsize=20, fontweight='bold')
# One panel per baseline model, filled in model order.
for ax, m in zip(axes, models):
    mName = type(m).__name__
    disp = ConfusionMatrixDisplay(confusion_matrix(y_test, globals()[f"y_pred_{mName}"]))
    disp.plot(ax=ax, values_format='.20g')
    ax.grid(False)  # presumably hides the grid drawn by the seaborn theme
    disp.ax_.set_title(mName, fontweight='bold', fontsize=12)
    # Drop the per-panel colorbar; each cell is annotated with its count.
    disp.im_.colorbar.remove()
plt.show()
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: value passed to every seeding call and written to the
            PYTHONHASHSEED environment variable.

    NOTE(review): Python reads PYTHONHASHSEED only at interpreter start-up, so
    setting it here likely affects only child processes — confirm if relied on.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Each of these seeds an independent generator; order does not matter.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
set_seed()
# TabNet classifier; cat_idxs / cat_dims tell it which input columns are
# categorical and how many levels each has.
tabnet = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=0.01),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 100, "gamma": 0.95},
    verbose=0,
    mask_type='sparsemax',
    cat_dims=cat_dims, cat_idxs=cat_idxs,
)
# Both train and validation sets are tracked in tabnet.history; early stopping
# reports on the 'valid' set's accuracy (see the run log).
tabnet.fit(
    x_train, y_train,
    eval_set=[(x_train, y_train), (x_val, y_val)],
    eval_name=['train', 'valid'],
    eval_metric=['auc', 'accuracy'],
    max_epochs=1000, patience=50,
    batch_size=16, virtual_batch_size=16,
    num_workers=0,
    drop_last=False,
)
y_pred = tabnet.predict(x_test)
# sklearn metrics take (y_true, y_pred) in that order.
test_acc = accuracy_score(y_test, y_pred)
preds_valid = tabnet.predict(x_val)
valid_acc = accuracy_score(y_val, preds_valid)
print("valid_acc:", valid_acc, ", test_acc:", test_acc)
Early stopping occurred at epoch 126 with best_epoch = 76 and best_valid_accuracy = 0.85185 valid_acc: 0.8518518518518519 , test_acc: 0.9444444444444444
# Per-epoch training loss recorded in TabNet's fit history.
fig = plt.figure(figsize=(10, 5))
fig.suptitle("Tabnet Training Loss", fontsize=20)
plt.gca().plot(tabnet.history['loss'])
[<matplotlib.lines.Line2D at 0x7f56006777d0>]
# Per-epoch train vs. validation accuracy from TabNet's fit history.
fig = plt.figure(figsize=(10, 5))
fig.suptitle("Tabnet Train and Valid Accuracy", fontsize=20)
# Label both curves; without a legend the two lines cannot be told apart.
plt.plot(tabnet.history['train_accuracy'], label='train')
plt.plot(tabnet.history['valid_accuracy'], label='valid')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
[<matplotlib.lines.Line2D at 0x7f56005f5190>]
# Score TabNet on the held-out test set with the same metrics as the baselines.
# AUROC uses the positive-class probability, not the hard labels in y_pred.
y_score = tabnet.predict_proba(x_test)[:, 1]
auroc = np.round(roc_auc_score(y_test, y_score), 4)
accuracy = np.round(accuracy_score(y_test, y_pred), 4)
precision = np.round(precision_score(y_test, y_pred), 4)
recall = np.round(recall_score(y_test, y_pred), 4)
f1 = np.round(f1_score(y_test, y_pred), 4)
metric_list = [['Tabular Net', auroc, accuracy, precision, recall, f1]]
df_metric_list = pd.DataFrame(metric_list)
df_metric_list.columns = ['modelName', 'auroc', 'accuracy', 'precision', 'recall', 'f1_score']
df_metric_list = df_metric_list.sort_values(["accuracy", "auroc"], ascending=False).reset_index(drop=True)
df_metric_list
| modelName | auroc | accuracy | precision | recall | f1_score | |
|---|---|---|---|---|---|---|
| 0 | Tabular Net | 0.9429 | 0.9444 | 0.9333 | 0.9333 | 0.9333 |
# Bold, large default text. 'normal' is not a resolvable font family
# (matplotlib falls back to the default), so only weight and size are set.
font = {'weight': 'bold', 'size': 18}
plt.rc('font', **font)
f, axes = plt.subplots(1, 1, figsize=(5, 5))
f.suptitle("Confusion Matrix For Tab Net", fontsize=20, fontweight='bold')
# Draw TabNet's test-set confusion matrix exactly once; there is a single
# model here, so no loop over the baseline models is needed.
disp = ConfusionMatrixDisplay(confusion_matrix(y_test, y_pred))
disp.plot(ax=axes, values_format='.20g')
axes.grid(False)
disp.im_.colorbar.remove()
plt.show()
# TabNet's global feature importances keyed by feature name; plot the 15 largest.
feat_importances = pd.Series(data=tabnet.feature_importances_, index=feats)
feat_importances.nlargest(n=15).plot.barh()
<AxesSubplot:>
# Horizontal bar chart of the top_k most important TabNet features.
fig = plt.figure(figsize=(5, 7))
importances = tabnet.feature_importances_
# Number of features to show. With the 13 features in this dataset, 9 is the
# same set as "all but the 4 least important".
top_k = 9
# argsort is ascending, so the most important feature ends up as the top bar.
indices = np.argsort(importances)[-top_k:]
plt.title('Feature Importances', fontweight='bold', fontsize=22)
plt.barh(range(len(indices)), importances[indices], color='g', align='center')
plt.yticks(range(len(indices)), [feats[i] for i in indices],
           fontweight='bold', fontsize=12)
plt.xlabel('Relative Importance')
plt.show()
TabNet's attention masks show which features the model relied on for each individual prediction.
# Bold, large default text ('normal' is not a resolvable font family, so it is omitted).
font = {'weight': 'bold', 'size': 18}
plt.rc('font', **font)
# masks holds one attention mask per decision step, indexed 0..n-1; each mask
# has one row per test sample and one column per feature.
explain_matrix, masks = tabnet.explain(x_test)
n_steps = len(masks)
f, axes = plt.subplots(1, n_steps, figsize=(12, 6))
# np.ravel also copes with a single step, where subplots returns a bare Axes.
axes = np.ravel(axes)
f.suptitle("Masks", fontsize=20, fontweight='bold')
for i, ax in enumerate(axes):
    ax.imshow(masks[i])
    ax.set_title(f"mask {i}")
plt.show()
# Total attention each feature received per test sample, summed over every
# decision step rather than a hard-coded number of steps.
masksum = sum(masks[step] for step in range(len(masks)))
masksumdf = pd.DataFrame(masksum, columns=feats)
masksumdf
| age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.421180 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.239193 | 0.254920 | 0.084707 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
| 1 | 0.000000 | 0.032162 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.410161 | 0.000000 | 0.557677 | 0.000000 | 1.000000 | 0.000000 |
| 2 | 0.000000 | 0.000000 | 0.000000 | 0.977114 | 0.0 | 0.000000 | 0.000000 | 0.468335 | 0.000000 | 0.531665 | 0.000000 | 0.022886 | 1.000000 |
| 3 | 0.272394 | 0.031191 | 0.000000 | 0.000000 | 0.0 | 0.288377 | 0.000000 | 0.054172 | 0.000000 | 1.020705 | 0.000000 | 0.893932 | 0.439229 |
| 4 | 0.000000 | 0.866500 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.564887 | 0.000000 | 0.042170 | 0.000000 | 0.000000 | 1.000000 | 0.526442 |
| 5 | 0.000000 | 0.000059 | 0.000000 | 0.999941 | 0.0 | 0.000000 | 0.000000 | 0.430287 | 0.000000 | 0.569713 | 0.000000 | 0.000000 | 1.000000 |
| 6 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.925404 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.074596 |
| 7 | 0.669275 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.176432 | 0.063401 | 0.090891 | 0.000000 | 0.815062 | 0.000000 | 1.184938 | 0.000000 |
| 8 | 0.000000 | 0.277727 | 0.000000 | 0.467442 | 0.0 | 0.254831 | 0.000000 | 0.477747 | 0.000000 | 0.522253 | 0.000000 | 1.000000 | 0.000000 |
| 9 | 0.091034 | 0.893204 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.486976 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.528785 |
| 10 | 0.000000 | 0.052177 | 0.000000 | 0.656777 | 0.0 | 0.000000 | 0.000000 | 0.540997 | 0.000000 | 0.459003 | 0.000000 | 0.291047 | 1.000000 |
| 11 | 0.000000 | 0.000000 | 0.001158 | 0.000000 | 0.0 | 0.884686 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.114155 |
| 12 | 0.180344 | 0.469384 | 0.000000 | 0.000000 | 0.0 | 0.255227 | 0.244604 | 0.155433 | 0.000000 | 0.000000 | 0.000000 | 1.419348 | 0.275661 |
| 13 | 0.274975 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.405568 | 0.030806 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.016979 | 1.271672 |
| 14 | 0.181538 | 0.301859 | 0.000000 | 0.000000 | 0.0 | 0.109740 | 0.505199 | 0.147025 | 0.000000 | 0.610345 | 0.000000 | 1.000000 | 0.144294 |
| 15 | 0.251564 | 0.867052 | 0.000000 | 0.000000 | 0.0 | 0.029788 | 0.326736 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.524859 |
| 16 | 0.000000 | 0.515716 | 0.000000 | 0.541488 | 0.0 | 0.000000 | 0.000000 | 0.282020 | 0.000000 | 0.613046 | 0.000000 | 1.047730 | 0.000000 |
| 17 | 0.225985 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.547413 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.226603 |
| 18 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.316066 | 0.287177 | 0.304699 | 0.084341 | 0.000000 | 0.000000 | 1.000000 | 1.007718 |
| 19 | 0.000000 | 0.707483 | 0.000000 | 0.000000 | 0.0 | 0.463950 | 0.000000 | 0.000000 | 0.000000 | 0.059420 | 0.000000 | 1.233096 | 0.536050 |
| 20 | 0.515521 | 0.028114 | 0.000000 | 0.000000 | 0.0 | 0.279003 | 0.131515 | 0.000000 | 0.000000 | 0.420060 | 0.000000 | 0.697103 | 0.928685 |
| 21 | 0.000000 | 0.753129 | 0.000000 | 0.136883 | 0.0 | 0.000000 | 0.229975 | 0.000000 | 0.000000 | 0.374388 | 0.000000 | 1.453410 | 0.052215 |
| 22 | 0.329486 | 0.006196 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.114499 | 0.010178 | 0.000000 | 0.914524 | 0.000000 | 0.606188 | 1.018930 |
| 23 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
| 24 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.874469 | 0.000000 | 0.125531 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
| 25 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.700572 | 0.011820 | 0.287608 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
| 26 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.505667 | 0.000000 | 0.494333 | 0.000000 | 1.000000 | 0.000000 |
| 27 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.637512 | 0.000000 | 0.362488 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
| 28 | 0.138939 | 0.597025 | 0.000000 | 0.000000 | 0.0 | 0.070083 | 0.552135 | 0.000000 | 0.000000 | 0.402975 | 0.000000 | 1.000000 | 0.238843 |
| 29 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.955776 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.044224 | 1.000000 | 1.000000 |
| 30 | 0.000000 | 0.624539 | 0.000000 | 0.100942 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.098922 | 0.000000 | 0.000000 | 1.632383 | 0.543213 |
| 31 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.669707 | 0.000000 | 0.330293 | 0.000000 | 0.000000 | 1.000000 |
| 32 | 0.000000 | 0.000000 | 0.013528 | 0.000000 | 0.0 | 0.986472 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
| 33 | 0.000000 | 0.385101 | 0.000000 | 0.645728 | 0.0 | 0.000000 | 0.000000 | 0.416254 | 0.000000 | 0.480495 | 0.000000 | 0.072423 | 1.000000 |
| 34 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.669707 | 0.000000 | 0.330293 | 0.000000 | 0.000000 | 1.000000 |
| 35 | 0.261785 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.385700 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.062956 | 1.289559 |
# Bold, large default text, set once before drawing. 'normal' is not a
# resolvable font family, so only weight and size are set.
font = {'weight': 'bold', 'size': 18}
plt.rc('font', **font)
plt.figure(figsize=(8, 6))
# Rows = test samples, columns = features, colour = summed mask value.
sns.heatmap(masksumdf, cbar=True)
plt.title('Features used for prediction', fontweight='bold', fontsize=16)
Text(0.5, 1.0, 'Features used for prediction')